package it.uniroma2.exp;

import it.uniroma2.exp.tools.AvgVarCalculator;
import it.uniroma2.util.tree.RandomTreeGenerator;
import it.uniroma2.util.tree.Tree;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;

/**
 * Collects descriptive statistics over three tree corpora and prints them: an artificial corpus
 * produced by {@link RandomTreeGenerator}, the Question Classification (QC) corpus and the
 * Textual Entailment Recognition (RTE) corpora.
 *
 * <p>For each corpus it records distinct terminal / non-terminal labels and their frequencies,
 * the number of nodes per tree, the number of terminals per tree ("sequence length"), the tree
 * depth and the branching factor of non-terminal nodes.
 */
public class TestbedAnalyzer {

	/** Number of random trees generated for the artificial corpus. */
	private static final int ARTIFICIAL_TREE_COUNT = 2000;

	/** A progress counter is printed every this many processed items. */
	private static final int PROGRESS_INTERVAL = 100;

	// Indices into the int[] returned by analyzeTree.
	private static final int NODE_COUNT = 0;
	private static final int SEQUENCE_LENGTH = 1;
	private static final int DEPTH = 2;

	/**
	 * Location of the Question Classification trees file (SVM Light format). Each line is
	 * expected to carry a Penn tree between {@code |BT|} and {@code |ET|} markers. Overridable
	 * via the {@code input.file.qc} system property.
	 */
	public final String inputTreesQC = System.getProperty("input.file.qc", "/home/lorenzo/esperimenti/dtk/qc/correlationTestData/corr_test_data.dat");

	/** Folder holding the RTE tree files; overridable via the {@code input.dir.rte} system property. */
	private static final String RTE_BASE_DIR = System.getProperty("input.dir.rte", "/home/lorenzo/workspace/arte/data/experiments/preprocessing/CPW/");

	/**
	 * RTE files to analyze. Each file is read line by line: a line starting with {@code <t>}
	 * immediately followed by a line starting with {@code <h>} is treated as a text/hypothesis
	 * pair of Penn trees (enclosed in {@code <t>...</t>} and {@code <h>...</h>}).
	 */
	public final String[] inputTreePairsRTE = {
			RTE_BASE_DIR + "RTE1_dev.xml"
			,RTE_BASE_DIR + "RTE1_test.xml"
			,RTE_BASE_DIR + "RTE2_dev.xml"
			,RTE_BASE_DIR + "RTE2_test.xml"
			,RTE_BASE_DIR + "RTE3_dev.xml"
			,RTE_BASE_DIR + "RTE3_test.xml"
			,RTE_BASE_DIR + "RTE4_test.xml"
			,RTE_BASE_DIR + "RTE5_dev.xml"
			,RTE_BASE_DIR + "RTE5_test.xml"
	};

	/**
	 * Accumulates the statistics of one corpus. It is filled by {@code addSample} and reported
	 * by {@link #printResults()}.
	 */
	public class TestbedProperties {
		/** Occurrence count of every terminal label. */
		Map<String, Integer> labelsTerm = new HashMap<String, Integer>();
		/** Occurrence count of every non-terminal label. */
		Map<String, Integer> labelsNonTerm = new HashMap<String, Integer>();
		/** Histogram: number of nodes in a tree -> number of trees of that size. */
		Map<Integer, Integer> nodeNumber = new TreeMap<Integer, Integer>();
		/** Histogram: number of terminals in a tree -> number of trees with that many terminals. */
		Map<Integer, Integer> sequenceLength = new TreeMap<Integer, Integer>();
		/** Children count of every non-terminal node. */
		AvgVarCalculator branchingFactor = new AvgVarCalculator();
		/** Children count of every non-terminal node for which {@code isPreTerminal()} is false. */
		AvgVarCalculator branchingFactorNonTerm = new AvgVarCalculator();
		/** Depth of every tree (a lone terminal has depth 0; each non-terminal level adds 1). */
		AvgVarCalculator treeDepth = new AvgVarCalculator();

		/**
		 * Prints the collected statistics to standard output.
		 *
		 * <p>It first prints 18 values, one per line: the total number of distinct labels,
		 * the distinct non-terminal labels and the distinct terminal labels, followed by
		 * avg/min/max of node count, branching factor, branching factor excluding
		 * pre-terminal parents, tree depth and sequence length.
		 *
		 * <p>It then prints four blank-line-separated, tab-separated histograms:
		 * <ul>
		 * <li>terminal label frequency -> number of labels with that frequency</li>
		 * <li>the same for non-terminal labels</li>
		 * <li>the node-count histogram</li>
		 * <li>the sequence-length histogram</li>
		 * </ul>
		 */
		public void printResults() {
			AvgVarCalculator nodeNum = statsOf(nodeNumber);
			AvgVarCalculator seqLen = statsOf(sequenceLength);
			int nonTermLabels = labelsNonTerm.size();
			int termLabels = labelsTerm.size();
			System.out.println(
					(termLabels + nonTermLabels) + "\n" + // # labels
					nonTermLabels + "\n" + // # non-terminal labels
					termLabels + "\n" + // # terminal labels
					nodeNum.getAvg() + "\n" + // average # nodes
					nodeNum.getMin() + "\n" + // minimum # nodes
					nodeNum.getMax() + "\n" + // maximum # nodes
					branchingFactor.getAvg() + "\n" + // average branching factor
					branchingFactor.getMin() + "\n" + // minimum branching factor
					branchingFactor.getMax() + "\n" + // maximum branching factor
					branchingFactorNonTerm.getAvg() + "\n" + // average branching factor (pre-terminal parents excluded)
					branchingFactorNonTerm.getMin() + "\n" + // minimum branching factor (pre-terminal parents excluded)
					branchingFactorNonTerm.getMax() + "\n" + // maximum branching factor (pre-terminal parents excluded)
					treeDepth.getAvg() + "\n" + // average depth
					treeDepth.getMin() + "\n" + // minimum depth
					treeDepth.getMax() + "\n" + // maximum depth
					seqLen.getAvg() + "\n" + // average sequence length (# terminal nodes)
					seqLen.getMin() + "\n" + // minimum sequence length (# terminal nodes)
					seqLen.getMax() + "\n" // maximum sequence length (# terminal nodes)
					);
			printHistogram(frequencyHistogram(labelsTerm.values()));
			System.out.println();
			printHistogram(frequencyHistogram(labelsNonTerm.values()));
			System.out.println();
			printHistogram(nodeNumber);
			System.out.println();
			printHistogram(sequenceLength);
		}
	}

	/**
	 * Entry point: runs the full analysis.
	 *
	 * @param args ignored
	 */
	public static void main(String[] args) {
		TestbedAnalyzer tba = new TestbedAnalyzer();
		tba.run();
	}

	/**
	 * Analyzes the artificial, QC and RTE corpora in turn, printing progress to standard
	 * output, and then prints the statistics of each corpus.
	 */
	public void run() {
		System.out.println("Artificial corpus");
		TestbedProperties artificialProperties = analyzeArtificialCorpus();
		System.out.println("done");

		System.out.println("QC");
		TestbedProperties qcProperties = analyzeQcCorpus();
		System.out.println("done");

		System.out.println("RTE");
		TestbedProperties rteProperties = analyzeRteCorpus();
		System.out.println("done");

		System.out.println("######## RESULTS ########");
		System.out.println("Artificial");
		artificialProperties.printResults();
		System.out.println("QC");
		qcProperties.printResults();
		System.out.println("RTE");
		rteProperties.printResults();
	}

	/** Generates {@link #ARTIFICIAL_TREE_COUNT} random trees (fixed seed 0) and collects their statistics. */
	private TestbedProperties analyzeArtificialCorpus() {
		TestbedProperties properties = new TestbedProperties();
		RandomTreeGenerator rtg = new RandomTreeGenerator(0);
		for (int i = 0; i < ARTIFICIAL_TREE_COUNT; i++) {
			printProgress(i);
			// Arguments are passed through unchanged; their meaning is defined by RandomTreeGenerator.
			String tree = rtg.generateRandomTree(6, 3, 15);
			try {
				addSample(Tree.fromPennTree(tree), properties);
			} catch (Exception e) {
				e.printStackTrace();
			}
		}
		return properties;
	}

	/**
	 * Reads the QC file and collects statistics of the tree found on each line. Lines without
	 * a {@code |BT|...|ET|} tree are reported on standard error and skipped instead of
	 * aborting the whole run.
	 */
	private TestbedProperties analyzeQcCorpus() {
		TestbedProperties properties = new TestbedProperties();
		BufferedReader in = null;
		try {
			in = new BufferedReader(new FileReader(inputTreesQC));
			int count = 0;
			String line;
			while ((line = in.readLine()) != null) {
				String tree = extractBetween(line, "|BT|", "|ET|");
				if (tree == null) {
					System.err.println("Skipping QC line without a |BT|...|ET| tree: " + line);
				} else {
					try {
						addSample(Tree.fromPennTree(tree.trim()), properties);
					} catch (Exception e) {
						e.printStackTrace();
					}
				}
				printProgress(count);
				count++;
			}
		} catch (IOException e) {
			e.printStackTrace();
		} finally {
			closeReader(in);
		}
		return properties;
	}

	/**
	 * Collects statistics over all RTE files. An I/O failure on one file is reported and the
	 * remaining files are still processed.
	 */
	private TestbedProperties analyzeRteCorpus() {
		TestbedProperties properties = new TestbedProperties();
		for (String fileName : inputTreePairsRTE) {
			try {
				analyzeRteFile(fileName, properties);
			} catch (IOException e) {
				e.printStackTrace();
			}
		}
		return properties;
	}

	/**
	 * Reads one RTE file and adds every text/hypothesis tree pair to {@code p}. A pair is a
	 * line starting with {@code <t>} immediately followed by a line starting with {@code <h>}.
	 * Malformed pairs (a missing closing tag, or a file ending after {@code <t>}) are skipped.
	 */
	private void analyzeRteFile(String fileName, TestbedProperties p) throws IOException {
		BufferedReader in = new BufferedReader(new FileReader(fileName));
		try {
			int count = 0;
			String line = in.readLine();
			while (line != null) {
				String textLine = line.trim();
				line = in.readLine();
				if (!textLine.startsWith("<t>") || line == null) {
					continue;
				}
				String hypothesisLine = line.trim();
				if (!hypothesisLine.startsWith("<h>")) {
					// Not a pair: keep 'line' so it is examined on the next iteration
					// (it may itself start a new <t> pair).
					continue;
				}
				line = in.readLine();
				String treet = extractBetween(textLine, "<t>", "</t>");
				String treeh = extractBetween(hypothesisLine, "<h>", "</h>");
				if (treet == null || treeh == null) {
					System.err.println("Skipping malformed RTE pair in " + fileName + ": " + textLine + " / " + hypothesisLine);
					continue;
				}
				try {
					addSample(Tree.fromPennTree(treet.trim()), p);
					addSample(Tree.fromPennTree(treeh.trim()), p);
				} catch (Exception e) {
					e.printStackTrace();
				}
				printProgress(count);
				count += 2;
			}
		} finally {
			closeReader(in);
		}
		System.out.println("end");
	}

	/** Adds the per-tree statistics of {@code t} (and of all its nodes) to {@code p}. */
	private void addSample(Tree t, TestbedProperties p) {
		int[] globalAttributes = analyzeTree(t, p);
		increment(p.nodeNumber, globalAttributes[NODE_COUNT]);
		increment(p.sequenceLength, globalAttributes[SEQUENCE_LENGTH]);
		p.treeDepth.addSample(globalAttributes[DEPTH]);
	}

	/**
	 * Recursively visits {@code t}, recording label counts and branching factors in {@code p}.
	 *
	 * @return {node count, number of terminals, depth} of the subtree rooted at {@code t};
	 *         a terminal yields {1, 1, 0}
	 */
	private int[] analyzeTree(Tree t, TestbedProperties p) {
		if (t.isTerminal()) {
			increment(p.labelsTerm, t.getRootLabel());
			return new int[]{1, 1, 0};
		}
		increment(p.labelsNonTerm, t.getRootLabel());
		p.branchingFactor.addSample(t.getChildren().size());
		if (!t.isPreTerminal()) {
			p.branchingFactorNonTerm.addSample(t.getChildren().size());
		}
		int[] globalAttributes = new int[]{0, 0, 0};
		for (Tree c : t.getChildren()) {
			int[] childAttributes = analyzeTree(c, p);
			globalAttributes[NODE_COUNT] += childAttributes[NODE_COUNT];
			globalAttributes[SEQUENCE_LENGTH] += childAttributes[SEQUENCE_LENGTH];
			globalAttributes[DEPTH] = Math.max(globalAttributes[DEPTH], childAttributes[DEPTH]);
		}
		globalAttributes[NODE_COUNT]++; // count this node
		globalAttributes[DEPTH]++; // this node adds one level above its deepest child
		return globalAttributes;
	}

	/** Increments the counter stored under {@code key}, starting from 1 when absent. */
	private static <K> void increment(Map<K, Integer> counts, K key) {
		Integer old = counts.get(key);
		counts.put(key, old == null ? 1 : old + 1);
	}

	/** Builds a sorted histogram: value -> number of times that value occurs in {@code values}. */
	private static Map<Integer, Integer> frequencyHistogram(Iterable<Integer> values) {
		Map<Integer, Integer> histogram = new TreeMap<Integer, Integer>();
		for (Integer value : values) {
			increment(histogram, value);
		}
		return histogram;
	}

	/** Expands a value -> count histogram into an AvgVarCalculator fed with every sample. */
	private static AvgVarCalculator statsOf(Map<Integer, Integer> histogram) {
		AvgVarCalculator stats = new AvgVarCalculator();
		for (Map.Entry<Integer, Integer> entry : histogram.entrySet()) {
			int value = entry.getKey();
			int occurrences = entry.getValue();
			for (int i = 0; i < occurrences; i++) {
				stats.addSample(value);
			}
		}
		return stats;
	}

	/** Prints each histogram entry as {@code key<TAB>count} on its own line. */
	private static void printHistogram(Map<Integer, Integer> histogram) {
		for (Map.Entry<Integer, Integer> entry : histogram.entrySet()) {
			System.out.println(entry.getKey() + "\t" + entry.getValue());
		}
	}

	/**
	 * Returns the text between the first {@code open} marker and the first {@code close}
	 * marker after it, or {@code null} if either marker is missing.
	 */
	private static String extractBetween(String s, String open, String close) {
		int start = s.indexOf(open);
		if (start < 0) {
			return null;
		}
		start += open.length();
		int end = s.indexOf(close, start);
		if (end < 0) {
			return null;
		}
		return s.substring(start, end);
	}

	/** Prints a dot per item and the item index every {@link #PROGRESS_INTERVAL} items. */
	private static void printProgress(int count) {
		System.out.print('.');
		if (count % PROGRESS_INTERVAL == 0) {
			System.out.println(count);
		}
	}

	/** Closes {@code in} if non-null, reporting (not propagating) a failure to close. */
	private static void closeReader(BufferedReader in) {
		if (in == null) {
			return;
		}
		try {
			in.close();
		} catch (IOException e) {
			e.printStackTrace();
		}
	}
}
